package com.virjar.dungproxy.server.proxyservice.client.decoder;
import com.virjar.dungproxy.server.entity.Proxy;
import com.virjar.dungproxy.server.proxyservice.client.exception.ServerChannelInactiveException;
import com.virjar.dungproxy.server.proxyservice.client.listener.ResponseListener;
import com.virjar.dungproxy.server.proxyservice.common.Constants;
import com.virjar.dungproxy.server.proxyservice.common.util.NetworkUtil;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ReplayingDecoder;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpConstants;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.AttributeKey;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import static com.virjar.dungproxy.server.proxyservice.client.decoder.HttpBodyDecoder.State.NO_CONTENT;
public class HttpHeaderDecoder extends ReplayingDecoder<HttpHeaderDecoder.State> {
private static final Logger LOGGER = LoggerFactory.getLogger(HttpHeaderDecoder.class);
public final static AttributeKey<Boolean> KEEP_ALIVE = AttributeKey.valueOf("keep-alive");
private static final String CR_LF = String.valueOf(new char[]{HttpConstants.CR, HttpConstants.LF});
protected static final ThreadLocal<StringBuilder> BUILDERS = new ThreadLocal<StringBuilder>() {
@Override
public StringBuilder initialValue() {
return new StringBuilder(512);
}
@Override
public StringBuilder get() {
StringBuilder sb = super.get();
sb.setLength(0);
return sb;
}
};
private ResponseListener listener;
private final int maxInitialLineLength;
private final int maxHeaderSize;
private HttpResponse message;
private int headerSize;
private ByteBuf initialLineBuf;
private boolean decodeFinished = false;
private boolean shouldKeepConnectionAlive = false;
private ByteBuf headerBuf;
private boolean retry;
private Proxy proxy;
enum State {
SKIP_CONTROL_CHARS,
READ_INITIAL,
READ_HEADER,
READ_VARIABLE_LENGTH_CONTENT,
READ_FIXED_LENGTH_CONTENT,
READ_CHUNK
}
/**
* Creates a new instance with the default
* {@code maxInitialLineLength (4096}}, {@code maxHeaderSize (8192)}
*/
public HttpHeaderDecoder(ResponseListener listener, Proxy proxy, boolean retry) {
this(4096, 8192, listener, retry, proxy);
}
/**
* Creates a new instance with the specified parameters.
*/
public HttpHeaderDecoder(
int maxInitialLineLength,
int maxHeaderSize,
ResponseListener listener,
boolean retry, Proxy proxy) {
super(State.SKIP_CONTROL_CHARS);
this.retry = retry;
this.proxy = proxy;
if (maxInitialLineLength <= 0) {
throw new IllegalArgumentException(
"maxInitialLineLength must be a positive integer: " +
maxInitialLineLength
);
}
if (maxHeaderSize <= 0) {
throw new IllegalArgumentException(
"maxHeaderSize must be a positive integer: " +
maxHeaderSize
);
}
this.maxInitialLineLength = maxInitialLineLength;
this.maxHeaderSize = maxHeaderSize;
this.listener = listener;
}
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
switch (state()) {
case SKIP_CONTROL_CHARS:
try {
skipControlCharacters(in);
checkpoint(State.READ_INITIAL);
} finally {
checkpoint();
}
case READ_INITIAL:
try {
String[] initialLine = splitInitialLine(readLine(in, maxInitialLineLength));
if (initialLine.length < 3) {
// Invalid initial line - ignore
finish(ctx, new BadMessageHandler(listener, new IllegalStateException("Invalid initial line")), in);
return;
}
message = createHeader(initialLine);
checkpoint(State.READ_HEADER);
if (in.readerIndex() == in.writerIndex()) {
initialLineBuf = in.copy(0, in.writerIndex());
initialLineBuf.readerIndex(0);
}
} catch (Exception e) {
// bad message, initial line too long e.g.
finish(ctx, new BadMessageHandler(listener, e), in);
return;
}
case READ_HEADER:
try {
State nextState = readHeaders(in);
LOGGER.debug("Header received [{}]", message);
boolean needRetry = listener.onHeaderReceived(message);
if (needRetry && retry) {
// return to retry
ctx.pipeline().remove(this);
in.readerIndex(in.writerIndex());
NetworkUtil.releaseMsgCompletely(initialLineBuf);
NetworkUtil.releaseMsgCompletely(headerBuf);
return;
}
checkpoint(nextState);
headerBuf = createHeaderBuf(in, Constants.PROXY_ROUTER_KEY, String.valueOf(proxy.getId()), in.readerIndex());
if (nextState == State.SKIP_CONTROL_CHARS) {
// No content is expected
finish(ctx, new HttpBodyNoContentDecoder(listener, NO_CONTENT, headerBuf, headerBuf.readerIndex()), in);
return;
}
long contentLength = HttpHeaders.getContentLength(message, -1);
// Hand over the unhandled data to next handler
switch (nextState) {
case READ_CHUNK: {
LOGGER.debug("Decode successfully, body type is CHUNKED");
finish(ctx, new HttpBodyChunkDecoder(listener, headerBuf, headerBuf.readerIndex()), in);
break;
}
case READ_FIXED_LENGTH_CONTENT: {
LOGGER.debug("Decode successfully, body type is FIX_LENGTH, length is {}", contentLength);
finish(ctx, new HttpBodyFixedLengthDecoder(listener, headerBuf, contentLength, headerBuf.readerIndex()), in);
break;
}
case READ_VARIABLE_LENGTH_CONTENT: {
LOGGER.debug("Decode successfully, body type is VARI_LENGTH");
finish(ctx, new HttpBodyVariLengthDecoder(listener, headerBuf, headerBuf.readerIndex()), in);
break;
}
default: {
NetworkUtil.releaseMsgCompletely(headerBuf);
throw new IllegalStateException("Unexpected state: " + nextState);
}
}
// We return here, this forces decode to be called again where we will decode the content
} catch (Exception e) {
if (headerBuf != null) {
NetworkUtil.releaseMsgCompletely(headerBuf);
}
finish(ctx, new BadMessageHandler(listener, e), in);
}
break;
}
}
/**
* Return a new ByteBuf which contains the specific header
*
* @param src original byte buf
* @param key header key
* @param value header val
* @param headerEndIndex the index of the end of last header
* @return ByteBuf contains the specific header
*/
private ByteBuf createHeaderBuf(ByteBuf src, String key, String value, int headerEndIndex) {
byte[] insertion = (key + ": " + value + CR_LF).getBytes();
ByteBuf newBuf;
int initBufLen = 0;
if (initialLineBuf != null) {
initBufLen = initialLineBuf.writerIndex();
newBuf = ByteBufAllocator.DEFAULT.directBuffer(src.writerIndex() + initialLineBuf.writerIndex() + insertion.length);
newBuf.writeBytes(initialLineBuf);
initialLineBuf.release();
initialLineBuf = null;
} else {
newBuf = ByteBufAllocator.DEFAULT.directBuffer(src.writerIndex() + insertion.length);
}
newBuf.writeBytes(src, 0, headerEndIndex - 2);
newBuf.writeBytes(insertion);
newBuf.writeBytes(src, headerEndIndex - 2, src.writerIndex() - headerEndIndex + 2);
newBuf.readerIndex(headerEndIndex + insertion.length + initBufLen);
return newBuf;
}
/**
* It's called once the http headers is decoded.
* The decode task will be ended and hand over to {@link HttpBodyDecoder} handler
*/
private void finish(
ChannelHandlerContext ctx,
ChannelHandler nextHandler,
ByteBuf cumulation
) {
// Set reader index to the writer index in order to release it in ByteToMessageDecoder
ctx.channel().attr(KEEP_ALIVE).set(shouldKeepConnectionAlive);
cumulation.readerIndex(cumulation.writerIndex());
decodeFinished = true;
ctx.pipeline().remove(this);
ctx.pipeline().addLast(nextHandler);
}
private static String[] splitInitialLine(StringBuilder sb) {
int protocolStart;
int protocolEnd;
int codeStart;
int codeEnd;
int textStart;
int textEnd;
protocolStart = findNonWhitespace(sb, 0);
protocolEnd = findWhitespace(sb, protocolStart);
codeStart = findNonWhitespace(sb, protocolEnd);
codeEnd = findWhitespace(sb, codeStart);
textStart = findNonWhitespace(sb, codeEnd);
textEnd = findLastNonWhitespace(sb);
return new String[]{
sb.substring(protocolStart, protocolEnd),
sb.substring(codeStart, codeEnd),
textStart < textEnd ? sb.substring(textStart, textEnd) : ""};
}
private StringBuilder readHeader(ByteBuf buf) {
StringBuilder sb = BUILDERS.get();
int headerSize = this.headerSize;
loop:
for (; ; ) {
char nextByte = (char) buf.readByte();
headerSize++;
switch (nextByte) {
case HttpConstants.CR:
nextByte = (char) buf.readByte();
headerSize++;
if (nextByte == HttpConstants.LF) {
break loop;
}
break;
case HttpConstants.LF:
break loop;
}
// Abort decoding if the message part is too large
if (headerSize >= maxHeaderSize) {
throw new TooLongFrameException("HTTP message is larger than " + maxHeaderSize + " bytes.");
}
sb.append(nextByte);
}
this.headerSize = headerSize;
return sb;
}
/**
* Create a response message with http initial line
*
* @param initialLine http initial line
* @return message Object
* @throws Exception
*/
private HttpResponse createHeader(String[] initialLine) throws Exception {
return new DefaultHttpResponse(
HttpVersion.valueOf(initialLine[0]),
new HttpResponseStatus(Integer.valueOf(initialLine[1]), initialLine[2]));
}
private State readHeaders(ByteBuf buf) {
headerSize = 0;
final HttpMessage message = this.message;
final HttpHeaders headers = message.headers();
StringBuilder line = readHeader(buf);
String hName = null;
String hVal = null;
if (line.length() > 0) {
headers.clear();
do {
char firstChar = line.charAt(0);
if (hName != null && (firstChar == ' ' || firstChar == '\t')) {
hVal = hVal + ' ' + line.toString().trim();
} else {
if (hName != null) {
headers.add(hName, hVal);
}
String[] currentHeader = splitHeader(line);
hName = currentHeader[0];
hVal = currentHeader[1];
}
line = readHeader(buf);
} while (line.length() > 0);
// last message
if (hName != null) {
headers.add(hName, hVal);
}
}
State nextState;
if (isContentAlwaysEmpty(message)) {
HttpHeaders.removeTransferEncodingChunked(message);
nextState = State.SKIP_CONTROL_CHARS;
} else if (HttpHeaders.isTransferEncodingChunked(message)) {
nextState = State.READ_CHUNK;
} else if (HttpHeaders.getContentLength(message, -1) >= 0) {
nextState = State.READ_FIXED_LENGTH_CONTENT;
} else {
nextState = State.READ_VARIABLE_LENGTH_CONTENT;
}
String connectionHeader = headers.get("Connection");
shouldKeepConnectionAlive = nextState != State.READ_VARIABLE_LENGTH_CONTENT && !message.getProtocolVersion().equals(HttpVersion.HTTP_1_0) && ((connectionHeader == null) || (connectionHeader.equalsIgnoreCase("keep-alive")));
LOGGER.debug("shouldKeepConnectionAlive:{}", shouldKeepConnectionAlive);
return nextState;
}
private static boolean isContentAlwaysEmpty(HttpMessage msg) {
if (msg instanceof HttpResponse) {
HttpResponse res = (HttpResponse) msg;
int code = res.getStatus().code();
// handle 1xx
if (code >= 100 && code < 200) {
return !(code == 101 && !res.headers().contains(HttpHeaders.Names.SEC_WEBSOCKET_ACCEPT));
}
switch (code) {
case 204:
case 205:
case 304:
return true;
}
}
return false;
}
private static String[] splitHeader(StringBuilder sb) {
final int length = sb.length();
int nameStart;
int nameEnd;
int colonEnd;
int valueStart;
int valueEnd;
nameStart = findNonWhitespace(sb, 0);
for (nameEnd = nameStart; nameEnd < length; nameEnd++) {
char ch = sb.charAt(nameEnd);
if (ch == ':' || Character.isWhitespace(ch)) {
break;
}
}
for (colonEnd = nameEnd; colonEnd < length; colonEnd++) {
if (sb.charAt(colonEnd) == ':') {
colonEnd++;
break;
}
}
valueStart = findNonWhitespace(sb, colonEnd);
if (valueStart == length) {
return new String[]{sb.substring(nameStart, nameEnd), ""};
}
valueEnd = findLastNonWhitespace(sb);
return new String[]{
sb.substring(nameStart, nameEnd),
sb.substring(valueStart, valueEnd)
};
}
/**
* read a line from {@code buf} util a <i>CR LF</i> or <i>LF</i> is encountered
*
* @param buf source buf
* @param maxLineLength max line length
* @return a {@code StringBuilder} represents the line text
*/
private static StringBuilder readLine(ByteBuf buf, int maxLineLength) {
StringBuilder sb = BUILDERS.get();
int lineLength = 0;
while (true) {
byte nextByte = buf.readByte();
if (nextByte == HttpConstants.CR) {
nextByte = buf.readByte();
if (nextByte == HttpConstants.LF) {
return sb;
}
} else if (nextByte == HttpConstants.LF) {
return sb;
} else {
if (lineLength >= maxLineLength) {
throw new TooLongFrameException(
"An HTTP line is larger than " + maxLineLength +
" bytes."
);
}
lineLength++;
sb.append((char) nextByte);
}
}
}
private static void skipControlCharacters(ByteBuf buf) {
for (; ; ) {
char c = (char) buf.readUnsignedByte();
if (!Character.isISOControl(c) && !Character.isWhitespace(c)) {
buf.readerIndex(buf.readerIndex() - 1);
break;
}
}
}
/**
* find the first non whitespace char of {@code cs} from {@code start} on
*
* @param cs target char sequence
* @param start index to start
* @return index of the first non whitespace char
*/
private static int findNonWhitespace(CharSequence cs, int start) {
int index;
for (index = start; index < cs.length(); index++) {
if (!Character.isWhitespace(cs.charAt(index))) {
break;
}
}
return index;
}
/**
* find the first whitespace char of {@code cs} from {@code start} on
*
* @param cs target char sequence
* @param start index to start
* @return index of the first whitespace char
*/
private static int findWhitespace(CharSequence cs, int start) {
int index;
for (index = start; index < cs.length(); index++) {
if (Character.isWhitespace(cs.charAt(index))) {
break;
}
}
return index;
}
/**
* find the last non whitespace char of {@code cs}
*
* @param cs target char sequence
* @return index of the last non whitespace char
*/
private static int findLastNonWhitespace(CharSequence cs) {
int index;
for (index = cs.length(); index > 0; index--) {
if (!Character.isWhitespace(cs.charAt(index - 1))) {
break;
}
}
return index;
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
LOGGER.debug("Exception occurred while decoding response headers", cause);
NetworkUtil.releaseMsgCompletely(initialLineBuf);
if (!decodeFinished) {
NetworkUtil.releaseMsgCompletely(headerBuf);
}
listener.onThrowable(getClass().getName() + " exceptionCaught", cause);
}
@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
NetworkUtil.releaseMsgCompletely(initialLineBuf);
if (!decodeFinished) {
String msg = "Server channel inactive while decoding response header";
LOGGER.debug(msg);
NetworkUtil.releaseMsgCompletely(headerBuf);
listener.onThrowable(msg, ServerChannelInactiveException.INSTANCE);
}
super.channelInactive(ctx);
}
}